import torch
import torch.nn as nn
from torchvision.transforms.functional import to_tensor
class BaseConditioner(nn.Module):
    """Base class for conditioners that yield a ``(condition, uncondition)`` pair.

    Subclasses customise the two hooks ``_impl_condition`` and
    ``_impl_uncondition``. Calling an instance runs the condition hook first and
    the uncondition hook second on the same input, and returns both results as a
    2-tuple.

    NOTE: ``__call__`` is overridden directly, so ``forward`` and nn.Module's
    forward hooks are not involved when an instance is called.
    """

    def __init__(self):
        super().__init__()

    def _impl_condition(self, y):
        """Map ``y`` to the conditional input (placeholder: returns None)."""
        return None

    def _impl_uncondition(self, y):
        """Map ``y`` to the unconditional input (placeholder: returns None)."""
        return None

    def __call__(self, y):
        # Tuple elements are evaluated left to right: condition, then uncondition.
        return self._impl_condition(y), self._impl_uncondition(y)
import numpy as np
class LabelConditioner(BaseConditioner):
    """Pass ``y`` through as the condition; use a fixed tensor as the uncondition.

    The unconditional branch is a tensor loaded once from a ``.npy`` file and
    expanded along the batch dimension to ``y.size(0)``.
    """

    def __init__(self, empty_npy, device=None):
        """Load the tensor used for the unconditional branch.

        Args:
            empty_npy: Anything ``np.load`` accepts (path or file-like) holding
                the array used for the unconditional branch. It is converted with
                torchvision's ``to_tensor``. NOTE(review): presumably a 2-D
                (seq_len, dim) embedding; the result must be expandable with
                ``expand(batch, -1, -1)``. Confirm with whoever produces the file.
            device: Device to place the tensor on. ``None`` (default) picks CUDA
                when it is available, which matches the previous hard-coded
                ``.cuda()``, and falls back to CPU instead of crashing.
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        empty = to_tensor(np.load(empty_npy)).to(device)
        # Registering a buffer lets module.to()/.cuda()/.cpu() move this tensor.
        # persistent=False keeps it out of state_dict, so existing checkpoints
        # still load with strict=True.
        self.register_buffer("empty_npy", empty, persistent=False)

    def _impl_condition(self, y):
        """Return ``y`` unchanged as the conditional input."""
        return y

    def _impl_uncondition(self, y):
        """Return the stored tensor expanded to the batch size of ``y``.

        The tensor is moved to ``y``'s device; this is a no-op when it is already
        there. ``expand`` returns a view without copying, so callers must not
        modify the result in place.
        """
        return self.empty_npy.to(y.device).expand(y.size(0), -1, -1)